import numpy as np
from Env import FiniteStateFiniteActionMDP
import matplotlib.pyplot as plt
import math

class Qlearning_genb_low:
    """Optimistic Q-learning with a variance-aware bonus and low-switching updates.

    ``tilde_Q`` is updated after every visit of a (step, state, action) triple,
    while the executed table ``global_Q`` (and hence the greedy policy) is only
    overwritten by ``tilde_Q`` when that triple's visit count lands in the
    schedule returned by :meth:`thresholdset`.  ``Nswitch`` counts how many
    episodes ended with a greedy policy different from the one just executed.

    Args:
        mdp: environment object; this class reads ``H``, ``S``, ``A`` and calls
            ``reset()``, ``step(action)``, ``best_gen()`` and
            ``value_gen(policy)``.
        c1: scale of the cap ``2 * c1 * (H - h - 1) * sqrt(H / N)`` on the bonus.
        c2: scale of the variance-based bonus term.
        total_episodes: number of episodes run by :meth:`learn`; also bounds
            the geometric part of the update schedule.
    """

    def __init__(self, mdp, c1, c2, total_episodes):
        self.mdp = mdp
        self.c1 = c1
        self.c2 = c2
        self.total_episodes = total_episodes
        self.Nswitch = 0  # number of policy switches counted by learn()

        H, S, A = self.mdp.H, self.mdp.S, self.mdp.A

        # Value estimates per step; row H is never written, so it stays 0 and
        # acts as the terminal value bootstrapped at the last step.
        self.tildeV_func = np.zeros((H + 1, S), dtype=np.float32)
        # Last, summed and summed-squared bootstrapped next-step values per triple.
        self.tildeV_next = np.zeros((H, S, A), dtype=np.float32)
        self.tildeV_sum = np.zeros((H, S, A), dtype=np.float32)
        self.tildeV2_sum = np.zeros((H, S, A), dtype=np.float32)

        # Optimistic initialisation Q[h, :, :] = H - h (presumably the maximal
        # return-to-go for rewards in [0, 1] -- assumption, not checked here).
        self.global_Q = np.empty((H, S, A), dtype=np.float32)
        self.global_Q[:] = (H - np.arange(H))[:, None, None]
        self.tilde_Q = self.global_Q.copy()

        self.N = np.zeros((H, S, A), dtype=np.int32)  # cumulative visit counts
        self.n = np.zeros((H, S, A), dtype=np.int32)  # visited-this-episode flags
        self.beta = np.zeros((H, S, A), dtype=np.float32)  # previous bonus level

        self.regret = []
        self.globalcost = []

    def run_episode(self):
        """Roll out one episode with the current greedy policy.

        Records, for every visited triple, the bootstrapped next-step value
        (last value, running sum and running sum of squares) and flags the
        triple in ``self.n``.

        Returns:
            ``(rewards, state_init)``: an ``(H, S, A)`` array holding the reward
            received at each visited triple (0 elsewhere) and the initial state.
        """
        actions_policy = self.choose_action()
        state = self.mdp.reset()
        state_init = state
        rewards = np.zeros((self.mdp.H, self.mdp.S, self.mdp.A))

        for step in range(self.mdp.H):
            action = np.argmax(actions_policy[step, state])
            next_state, reward = self.mdp.step(action)

            # Mark the triple as visited in this episode (each step occurs once
            # per episode, so a flag is enough).
            self.n[step, state, action] = 1

            next_value = self.tildeV_func[step + 1, next_state]
            self.tildeV_next[step, state, action] = next_value
            self.tildeV_sum[step, state, action] += next_value
            self.tildeV2_sum[step, state, action] += next_value ** 2

            rewards[step, state, action] = reward
            state = next_state
        return rewards, state_init

    def choose_action(self):
        """Return the greedy policy of ``global_Q`` as a one-hot ``(H, S, A)`` array.

        Ties are broken towards the lowest action index (``np.argmax``).
        """
        best_actions = np.argmax(self.global_Q, axis=2)
        return np.eye(self.mdp.A)[best_actions]

    def update_tildeQ(self, rewards):
        """Apply one Q-learning update to every triple flagged in ``self.n``.

        Increments ``N``, updates ``tilde_Q`` with step size
        ``(H + 1) / (H + N)`` plus the bonus, copies the result into
        ``global_Q`` when ``N`` is in the switching schedule, and finally clears
        the visit flags.

        Args:
            rewards: ``(H, S, A)`` array of rewards to use for the update.
        """
        H, S, A = self.mdp.H, self.mdp.S, self.mdp.A
        # The schedule does not depend on (h, s, a): build it once per call and
        # use a set so each membership test is O(1).
        update = set(self.thresholdset())

        # Triples are updated independently, so only visiting the flagged ones
        # is equivalent to the full H x S x A scan.
        for h, s, a in np.argwhere(self.n):
            self.N[h, s, a] += 1
            N_h_k = self.N[h, s, a]
            step_size = (H + 1) / (H + N_h_k)
            ucb_bonus = self.c1 * (H - h - 1) * np.sqrt(H / N_h_k)

            # Empirical variance of the bootstrapped next-step value, clipped at 0
            # against floating-point cancellation.
            mean_v = self.tildeV_sum[h, s, a] / N_h_k
            sigma2_v = max(self.tildeV2_sum[h, s, a] / N_h_k - mean_v ** 2, 0)

            betanew = min(
                self.c2 * (np.sqrt(H * (sigma2_v + H) / N_h_k)
                           + np.sqrt(H ** 7 * S * A) / N_h_k),
                2 * ucb_bonus,
            )
            # Incremental bonus chosen so the accumulated bonus tracks betanew.
            bonus = (betanew - (1 - step_size) * self.beta[h, s, a]) / (2 * step_size)
            self.beta[h, s, a] = betanew

            self.tilde_Q[h, s, a] = (1 - step_size) * self.tilde_Q[h, s, a] + \
                step_size * (rewards[h, s, a] + self.tildeV_next[h, s, a] + bonus)

            if int(N_h_k) in update:
                self.global_Q[h, s, a] = self.tilde_Q[h, s, a]
        self.n.fill(0)

    def thresholdset(self):
        """Return the visit counts at which ``global_Q`` is synchronised.

        With ``eta = 1 + 1 / (2 H (H + 1))``, ``tau(r) = ceil(eta ** r)`` and
        ``r* = ceil(log_eta(10 H**2))``, the schedule is every count
        ``1 .. tau(r*)`` followed by ``tau(r)`` for
        ``r* < r <= floor(log_eta(total_episodes))``.

        Returns:
            A list of ints (may contain duplicates in the geometric part).
        """
        H = self.mdp.H
        eta = 1 + 1 / (2 * H * (H + 1))
        log_eta = np.log(eta)

        def tau(r):
            return math.ceil(eta ** r)

        # H ** 2 is the square; ``H ^ 2`` would be a bitwise XOR in Python.
        rstar = math.ceil(np.log(10 * H ** 2) / log_eta)
        # Guard log(0): with fewer than one episode there is no geometric part.
        r = int(np.log(max(self.total_episodes, 1)) / log_eta)

        updateset = list(range(1, tau(rstar) + 1))
        if r > rstar:
            updateset += [tau(i) for i in range(rstar + 1, r + 1)]
        return updateset

    def _refresh_tildeV(self):
        """Set ``tildeV_func[h, s] = min(H - h, max_a tilde_Q[h, s, a])`` for h < H."""
        H = self.mdp.H
        caps = (H - np.arange(H))[:, None]
        self.tildeV_func[:H] = np.minimum(caps, self.tilde_Q.max(axis=2))

    def learn(self):
        """Run ``total_episodes`` episodes, recording regret and policy switches.

        Appends the cumulative regret to ``self.regret`` and the running switch
        count to ``self.globalcost`` after every episode.

        Returns:
            ``(best_Q, global_Q)``: the Q-table returned by ``mdp.best_gen()``
            and the learned executed Q-table.
        """
        self.regret_cum = 0  # cumulative regret
        best_value, best_policy, best_Q = self.mdp.best_gen()

        # Reward table passed to the updates: a nonzero reward recorded for a
        # triple is kept; zero entries are refreshed from the latest episode.
        rewards = np.zeros((self.mdp.H, self.mdp.S, self.mdp.A))
        self._refresh_tildeV()
        actions_policy = self.choose_action()
        self.n_switch = actions_policy  # policy executed in the previous episode

        for episode in range(self.total_episodes):
            run_reward, state_init = self.run_episode()
            value = self.mdp.value_gen(actions_policy)
            self.regret_cum = self.regret_cum + best_value[state_init] - value[state_init]
            self.regret.append(self.regret_cum)

            for h in range(self.mdp.H):
                for s in range(self.mdp.S):
                    a = np.argmax(actions_policy[h, s])
                    if rewards[h, s, a] == 0:
                        rewards[h, s, a] = run_reward[h, s, a]

            self.update_tildeQ(rewards)
            actions_policy = self.choose_action()
            self._refresh_tildeV()

            if not np.array_equal(self.n_switch, actions_policy):
                self.Nswitch += 1
            self.globalcost.append(self.Nswitch)
            self.n_switch = actions_policy
        return best_Q, self.global_Q